# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import tensorrt as trt

from ..._common import default_net
from ..._utils import pad_vocab_size, str_dtype_to_trt
from ...functional import (RotaryScalingType, Tensor, gather_last_token_logits,
                           gpt_attention, partial, recv, send, unary)
from ...layers import (AttentionMaskType, AttentionParams, ColumnLinear,
                       Embedding, GatedMLP, KeyValueCacheParams,
                       PositionEmbeddingType, RmsNorm, RowLinear)
from ...mapping import Mapping
from ...module import Module, ModuleList
from ...parameter import Parameter
from ...quantization import QuantMode
from ...quantization.layers import FP8Linear, FP8RowLinear
from ..generation_mixin import GenerationMixin

# Element-wise unary helpers: `unary` pre-bound (via `partial`) to TensorRT's
# LOG and CEIL unary operations.
# NOTE(review): neither helper is referenced by the classes below in this
# file -- confirm whether they are still needed by external importers.
log = partial(unary, op=trt.UnaryOperation.LOG)
ceil = partial(unary, op=trt.UnaryOperation.CEIL)


class QWenAttention(Module):
    """Self-attention layer for QWen, executed through the GPT attention plugin.

    The layer owns a fused QKV projection (``qkv``) and an output projection
    (``dense``). ``forward`` feeds the fused projection to ``gpt_attention``
    together with the rotary-embedding and KV-cache settings collected by the
    constructor.

    Constructor arguments, as used by this class:
        hidden_size, num_attention_heads: unsharded sizes. The values stored
            on the instance (``hidden_size``, ``num_attention_heads``) are
            divided by ``tp_size``; ``num_heads`` keeps the unsharded count.
        num_kv_heads: unsharded KV head count; ``None`` means one KV head per
            sharded attention head.
        apply_query_key_layer_scaling: when true, ``norm_factor`` and
            ``q_scaling`` are multiplied by ``num_layers``.
        rotary_embedding_base: rotary base forwarded to the plugin.
        rotary_embedding_scaling: optional dict with keys ``"type"``
            (``"linear"`` or ``"dynamic"``) and ``"factor"`` (must be > 1.0).
        rotary_embedding_percentage: fraction of the head size that receives
            rotary embedding when the position type is ``rope_gpt_neox``.
        use_int8_kv_cache: creates the two KV-cache scale parameters and sets
            the INT8 KV-cache flag on ``quant_mode``.
        bias: applies to the output projection only; the QKV projection is
            always created with a bias.
        relative_attention, num_buckets: when ``relative_attention`` is true a
            ``rel_attn_table`` parameter of shape
            ``(num_attention_heads // tp_size, num_buckets)`` is created.

    ``seq_length``, ``cross_attention``, ``max_distance``,
    ``neox_rotary_style``, ``use_dynamic_ntk`` and ``use_logn_attn`` are only
    stored on the instance; ``forward`` does not read them.
    """

    def __init__(
        self,
        hidden_size,
        num_attention_heads,
        max_position_embeddings,
        seq_length,  # 2048
        num_kv_heads=None,
        num_layers=1,
        apply_query_key_layer_scaling=False,
        attention_mask_type=AttentionMaskType.causal,
        bias=True,
        dtype=None,
        position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
        rotary_embedding_base=10000.0,
        rotary_embedding_scaling=None,
        neox_rotary_style=False,
        use_int8_kv_cache=False,
        rotary_embedding_percentage=1.0,
        tp_group=None,
        tp_size=1,
        quant_mode: QuantMode = QuantMode(0),
        q_scaling=1.0,
        cross_attention=False,
        relative_attention=False,
        max_distance=0,
        num_buckets=0,
        instance_id: int = 0,
        use_dynamic_ntk=True,
        use_logn_attn=True,
    ):
        super().__init__()
        self.cross_attention = cross_attention
        self.seq_length = seq_length
        self.num_heads = num_attention_heads

        self.attention_mask_type = attention_mask_type
        self.bias = bias
        # Head size is derived from the unsharded sizes; everything below is
        # the per-tensor-parallel-rank share.
        self.attention_head_size = hidden_size // num_attention_heads
        self.num_attention_heads = num_attention_heads // tp_size
        self.num_attention_kv_heads = (
            num_kv_heads + tp_size - 1
        ) // tp_size if num_kv_heads is not None else self.num_attention_heads
        self.hidden_size = hidden_size // tp_size
        self.max_position_embeddings = max_position_embeddings

        self.num_layers = num_layers
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.norm_factor = math.sqrt(self.attention_head_size)
        self.q_scaling = q_scaling
        if self.apply_query_key_layer_scaling:
            self.norm_factor *= self.num_layers
            self.q_scaling *= self.num_layers

        self.position_embedding_type = position_embedding_type

        self.relative_attention = relative_attention
        self.max_distance = max_distance

        # Rotary settings; all three are forwarded to the plugin in forward().
        self.rotary_embedding_base = rotary_embedding_base
        self.rotary_embedding_scale_type = RotaryScalingType.none
        self.rotary_embedding_scale = 1.0
        if rotary_embedding_scaling is not None:
            assert rotary_embedding_scaling["type"] in ["linear", "dynamic"]
            self.rotary_embedding_scale_type = RotaryScalingType.linear if rotary_embedding_scaling[
                "type"] == "linear" else RotaryScalingType.dynamic
            self.rotary_embedding_scale = rotary_embedding_scaling["factor"]
            assert self.rotary_embedding_scale > 1.0
        self.rotary_embedding_dim = 0
        self.neox_rotary_style = neox_rotary_style
        if self.position_embedding_type == PositionEmbeddingType.rope_gpt_neox:
            self.rotary_embedding_dim = int(self.attention_head_size *
                                            rotary_embedding_percentage)

        self.dtype = dtype
        self.quant_mode = quant_mode
        if use_int8_kv_cache:
            # TODO: remove use_int8_kv_cache as can be replaced by quant_mode.has_kv_cache_quant()
            # Merge int8 setting into quant_mode
            self.quant_mode = self.quant_mode.set_int8_kv_cache()

        self.use_int8_kv_cache = use_int8_kv_cache
        if self.use_int8_kv_cache:
            self.kv_orig_quant_scale = Parameter(shape=(1, ), dtype='float32')
            self.kv_quant_orig_scale = Parameter(shape=(1, ), dtype='float32')
        else:
            self.register_parameter('kv_orig_quant_scale', None)
            self.register_parameter('kv_quant_orig_scale', None)

        # The FP8 and regular projections take the same arguments, so only the
        # layer classes depend on the quantization mode.
        self.use_fp8_qdq = self.quant_mode.has_fp8_qdq()
        if self.use_fp8_qdq:
            qkv_cls, dense_cls = FP8Linear, FP8RowLinear
        else:
            qkv_cls, dense_cls = ColumnLinear, RowLinear

        # Fused output width: Q (hidden_size) plus K and V for all KV heads
        # across every tensor-parallel rank.
        qkv_out_features = hidden_size + (
            2 * tp_size * self.num_attention_kv_heads *
            self.attention_head_size)
        self.qkv = qkv_cls(hidden_size,
                           qkv_out_features,
                           bias=True,
                           dtype=dtype,
                           tp_group=tp_group,
                           tp_size=tp_size,
                           gather_output=False)
        self.dense = dense_cls(hidden_size,
                               hidden_size,
                               bias=bias,
                               dtype=dtype,
                               tp_group=tp_group,
                               tp_size=tp_size,
                               instance_id=instance_id)

        if relative_attention:
            self.rel_attn_table = Parameter(shape=(num_attention_heads //
                                                   tp_size, num_buckets),
                                            dtype=dtype)

        self.use_dynamic_ntk = use_dynamic_ntk
        self.use_logn_attn = use_logn_attn

    def forward(
        self,
        hidden_states: Tensor,
        use_cache=False,
        kv_cache_params=None,
        attention_params=None,
        workspace=None,
    ):
        """Project to QKV, run the GPT attention plugin and the output projection.

        Args:
            hidden_states: input ``Tensor``.
            use_cache: when true, also return the plugin's KV-cache output.
            kv_cache_params: per-layer KV-cache inputs; required (its
                attributes are read unconditionally).
            attention_params: sequence/context length inputs; required.
            workspace: forwarded to the output projection.

        Returns:
            ``(context, past_key_value)`` when ``use_cache`` is true, otherwise
            ``context``.

        Raises:
            ValueError: if the GPT attention plugin is not enabled on the
                current network.
        """
        if not default_net().plugin_config.gpt_attention_plugin:
            raise ValueError('QWen is only supported with GPTAttention plugin')

        assert isinstance(hidden_states, Tensor)
        qkv = self.qkv(hidden_states)

        kv_orig_quant_scale = self.kv_orig_quant_scale.value if self.use_int8_kv_cache else None
        kv_quant_orig_scale = self.kv_quant_orig_scale.value if self.use_int8_kv_cache else None

        context, past_key_value = gpt_attention(
            tensor=qkv,
            past_key_value=kv_cache_params.get_first_past_key_value(),
            sequence_length=attention_params.sequence_length,
            host_past_key_value_lengths=kv_cache_params.
            host_past_key_value_lengths,
            host_max_kv_cache_lengths=kv_cache_params.host_max_kv_cache_lengths,
            context_lengths=attention_params.context_lengths,
            cache_indirection=kv_cache_params.cache_indirection,
            host_request_types=attention_params.host_request_types,
            num_heads=self.num_attention_heads,
            num_kv_heads=self.num_attention_kv_heads,
            hidden_size_per_head=self.attention_head_size,
            q_scaling=self.q_scaling,
            rotary_embedding_dim=self.
            rotary_embedding_dim,  # when we use it 0, we will not use rotary embedding in plugin
            # Forward the rotary configuration gathered in __init__. The
            # scale type must be a RotaryScalingType, not the neox-style flag.
            rotary_embedding_base=self.rotary_embedding_base,
            rotary_embedding_scale_type=self.rotary_embedding_scale_type,
            rotary_embedding_scale=self.rotary_embedding_scale,
            rotary_embedding_max_positions=self.max_position_embeddings,
            position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
            kv_orig_quant_scale=kv_orig_quant_scale,
            kv_quant_orig_scale=kv_quant_orig_scale,
            kv_cache_quant_mode=QuantMode.from_description(
                use_int8_kv_cache=self.use_int8_kv_cache),
            kv_cache_block_pointers=kv_cache_params.
            get_first_kv_cache_block_pointers(),
            max_context_length=attention_params.max_context_length,
            mask_type=self.attention_mask_type.value,
            host_context_lengths=attention_params.host_context_lengths)

        context = self.dense(context, workspace=workspace)

        if use_cache:
            return (context, past_key_value)
        else:
            return context


class QWenBlock(Module):
    """A single QWen transformer layer.

    Data flow in ``forward``: ``ln_1`` -> attention -> residual add, followed
    by ``ln_2`` -> gated MLP -> residual add.

    ``mlp_hidden_size`` falls back to ``4 * hidden_size`` when falsy; the
    gated MLP is built with half of that value as its ``ffn_hidden_size``.
    ``bias`` is handed to the attention layer only; the MLP is always created
    without bias.
    """

    def __init__(self,
                 layer_id,
                 hidden_size,
                 seq_length,
                 num_attention_heads,
                 max_position_embeddings,
                 num_layers,
                 dtype=None,
                 attention_mask_type=AttentionMaskType.causal,
                 apply_query_key_layer_scaling=False,
                 hidden_act='silu',
                 position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
                 rotary_base=10000.0,
                 rotary_scaling=None,
                 quant_mode=QuantMode(0),
                 mlp_hidden_size=None,
                 neox_rotary_style=True,
                 bias=False,
                 tp_group=None,
                 tp_size=1,
                 rms_norm_eps=1e-06):
        super().__init__()

        # Plain configuration attributes (the raw, possibly-None
        # mlp_hidden_size is what gets stored).
        self._layer_id = layer_id  # useful for debugging
        self.hidden_size = hidden_size
        self.seq_length = seq_length
        self.num_attention_heads = num_attention_heads
        self.max_position_embeddings = max_position_embeddings
        self.num_layers = num_layers
        self.dtype = dtype
        self.attention_mask_type = attention_mask_type
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.hidden_act = hidden_act
        self.position_embedding_type = position_embedding_type
        self.mlp_hidden_size = mlp_hidden_size
        self.neox_rotary_style = neox_rotary_style
        self.bias = bias
        self.tp_group = tp_group
        self.tp_size = tp_size

        # Sub-modules, created in the order ln_1, attention, mlp, ln_2.
        self.ln_1 = RmsNorm(normalized_shape=hidden_size,
                            eps=rms_norm_eps,
                            dtype=dtype)

        self.attention = QWenAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            max_position_embeddings=max_position_embeddings,
            num_layers=num_layers,
            seq_length=seq_length,
            dtype=dtype,
            attention_mask_type=attention_mask_type,
            bias=bias,
            position_embedding_type=position_embedding_type,
            rotary_embedding_base=rotary_base,
            rotary_embedding_scaling=rotary_scaling,
            neox_rotary_style=neox_rotary_style,
            tp_group=tp_group,
            tp_size=tp_size,
            use_int8_kv_cache=quant_mode.has_int8_kv_cache(),
        )

        # Falsy mlp_hidden_size (None or 0) selects the 4x default.
        effective_mlp_size = mlp_hidden_size or hidden_size * 4

        self.mlp = GatedMLP(hidden_size=hidden_size,
                            ffn_hidden_size=effective_mlp_size // 2,
                            hidden_act=hidden_act,
                            dtype=dtype,
                            bias=False,
                            tp_group=tp_group,
                            tp_size=tp_size,
                            quant_mode=quant_mode,
                            instance_id=2 * layer_id + 1)
        self.ln_2 = RmsNorm(normalized_shape=hidden_size,
                            eps=rms_norm_eps,
                            dtype=dtype)

    def forward(
        self,
        hidden_states: Tensor,
        use_cache=False,
        kv_cache_params=None,
        attention_params=None,
        all_reduce_workspace=None,
    ):
        """Run the layer on ``hidden_states``.

        Returns ``(hidden_states, presents)`` when ``use_cache`` is true and
        just ``hidden_states`` otherwise, where ``presents`` is the second
        element returned by the attention layer.
        """
        # Attention sub-layer with residual connection.
        normed_input = self.ln_1(hidden_states)
        attn_result = self.attention(
            normed_input,
            use_cache=use_cache,
            kv_cache_params=kv_cache_params,
            attention_params=attention_params,
            workspace=all_reduce_workspace,
        )
        if use_cache:
            attn_out, presents = attn_result
        else:
            attn_out = attn_result
        post_attention = hidden_states + attn_out

        # MLP sub-layer with residual connection.
        mlp_out = self.mlp(self.ln_2(post_attention))
        block_output = post_attention + mlp_out

        if not use_cache:
            return block_output
        return (block_output, presents)


class QWenModel(Module):
    """QWen transformer stack: optional embedding, a list of blocks, final norm.

    The vocabulary embedding is only created on the first pipeline-parallel
    rank. The blocks are created for the layer indices returned by
    ``self.get_transformer_layers(mapping, num_layers)``.

    NOTE(review): ``get_transformer_layers`` is not defined in this class; it
    is presumably supplied by a base class or mixin -- confirm before
    instantiating ``QWenModel`` on its own.
    """

    def __init__(
        self,
        num_layers,
        num_heads,
        hidden_size,
        seq_length,
        vocab_size,
        hidden_act,
        max_position_embeddings,
        dtype,
        mlp_hidden_size=None,
        position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
        neox_rotary_style=True,
        bias=False,
        rotary_base=10000.0,
        rotary_scaling=None,
        mapping=Mapping(),
        quant_mode=QuantMode(0),
        use_parallel_embedding=False,
        embedding_sharding_dim=0,
        rms_norm_eps=1e-06,
    ):
        super().__init__()
        self.mapping = mapping
        if self.mapping.is_first_pp_rank():
            # The embedding is tensor-parallel only when explicitly requested.
            self.vocab_embedding = Embedding(
                num_embeddings=vocab_size,
                embedding_dim=hidden_size,
                dtype=dtype,
                tp_size=mapping.tp_size if use_parallel_embedding else 1,
                tp_group=mapping.tp_group if use_parallel_embedding else None,
                sharding_dim=embedding_sharding_dim,
                tp_rank=mapping.tp_rank)

        self.layers = ModuleList([
            QWenBlock(layer_id=i,
                      hidden_size=hidden_size,
                      seq_length=seq_length,
                      num_attention_heads=num_heads,
                      num_layers=num_layers,
                      max_position_embeddings=max_position_embeddings,
                      dtype=dtype,
                      hidden_act=hidden_act,
                      quant_mode=quant_mode,
                      mlp_hidden_size=mlp_hidden_size,
                      position_embedding_type=position_embedding_type,
                      rotary_base=rotary_base,
                      rotary_scaling=rotary_scaling,
                      neox_rotary_style=neox_rotary_style,
                      bias=bias,
                      tp_group=mapping.tp_group,
                      tp_size=mapping.tp_size,
                      rms_norm_eps=rms_norm_eps)
            for i in self.get_transformer_layers(self.mapping, num_layers)
        ])

        self.ln_f = RmsNorm(normalized_shape=hidden_size,
                            eps=rms_norm_eps,
                            dtype=dtype)

    def forward(self,
                input_ids,
                position_ids=None,
                use_cache=False,
                kv_cache_params=None,
                attention_params=None,
                hidden_states=None,
                all_reduce_workspace=None):
        """Run the embedding (or pipeline receive), all blocks and the final norm.

        Args:
            input_ids: token ids, embedded on the first pipeline rank.
            position_ids: accepted for interface compatibility; not read here.
            use_cache: when true, collect each block's KV-cache output.
            kv_cache_params: KV-cache inputs for all layers; required.
            attention_params: forwarded unchanged to every block.
            hidden_states: on non-first pipeline ranks, the tensor passed to
                ``recv``; ignored on the first rank.
            all_reduce_workspace: forwarded unchanged to every block.

        Returns:
            ``(hidden_states, presents)`` with ``presents`` a tuple (one entry
            per local layer) when ``use_cache`` is true, else
            ``hidden_states``. On the last pipeline rank the result has passed
            through ``ln_f``; on other ranks it is the result of ``send``.
        """
        # Presumably normalizes absent per-layer cache entries so the zip()
        # below can iterate them -- see KeyValueCacheParams to confirm.
        kv_cache_params.fill_none_tensor_list(len(self.layers))

        if use_cache:
            presents = []

        if self.mapping.is_first_pp_rank():
            hidden_states = self.vocab_embedding(input_ids)
        else:
            hidden_states = recv(hidden_states, self.mapping.prev_pp_rank())
        self.register_network_output("embd", hidden_states)

        for layer, past, pointer, max_kv_cache_length in zip(
                self.layers, kv_cache_params.past_key_value,
                kv_cache_params.kv_cache_block_pointers,
                kv_cache_params.host_max_kv_cache_lengths):
            # Each block receives its own slice of the per-layer cache inputs;
            # the remaining cache inputs are shared across layers.
            hidden_states = layer(
                hidden_states,
                use_cache=use_cache,
                kv_cache_params=KeyValueCacheParams(
                    past_key_value=[past],
                    host_past_key_value_lengths=kv_cache_params.
                    host_past_key_value_lengths,
                    host_max_kv_cache_lengths=max_kv_cache_length,
                    kv_cache_block_pointers=[pointer],
                    cache_indirection=kv_cache_params.cache_indirection),
                attention_params=attention_params,
                all_reduce_workspace=all_reduce_workspace)

            if use_cache:
                presents.append(hidden_states[1])
                hidden_states = hidden_states[0]

        if self.mapping.is_last_pp_rank():
            hidden_states = self.ln_f(hidden_states)
        else:
            hidden_states = send(hidden_states, self.mapping.next_pp_rank())

        if use_cache:
            return (hidden_states, tuple(presents))
        return hidden_states


class QWenForCausalLM(QWenModel, GenerationMixin):
    """QWen transformer stack with a language-model head on the last pipeline rank.

    ``dtype`` and ``logits_dtype`` may be given either as strings or as
    ``trt.DataType`` values; the converted values are stored on the instance.
    A ``num_kv_heads`` that is ``None`` or not positive is replaced by
    ``num_heads``.
    """

    def __init__(
        self,
        num_layers,
        num_heads,
        num_kv_heads,
        hidden_size,
        seq_length,
        vocab_size,
        hidden_act,
        max_position_embeddings,
        dtype,
        logits_dtype="float32",
        mlp_hidden_size=None,
        neox_rotary_style=True,
        rotary_base=10000.0,
        rotary_scaling=None,
        mapping=Mapping(),
        quant_mode=QuantMode(0),
        use_parallel_embedding=False,
        embedding_sharding_dim=0,
        rms_norm_eps=1e-06,
    ):
        self.mapping = mapping

        # Normalize the compute dtype to a TensorRT data type.
        if not isinstance(dtype, str):
            assert isinstance(dtype, trt.DataType)
            self.dtype = dtype
        else:
            self.dtype = str_dtype_to_trt(dtype)

        # Same normalization for the dtype used when marking the logits output.
        if not isinstance(logits_dtype, str):
            assert isinstance(logits_dtype, trt.DataType)
            self.logits_dtype = logits_dtype
        else:
            self.logits_dtype = str_dtype_to_trt(logits_dtype)

        self.num_layers = num_layers
        self.num_heads = num_heads
        kv_heads_given = num_kv_heads is not None and num_kv_heads > 0
        self.num_kv_heads = num_kv_heads if kv_heads_given else num_heads
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.tp_size = mapping.tp_size

        # The KV-cache dtype follows the quantization mode, defaulting to the
        # compute dtype.
        if quant_mode.has_int8_kv_cache():
            kv_dtype = str_dtype_to_trt('int8')
        elif quant_mode.has_fp8_kv_cache():
            kv_dtype = str_dtype_to_trt('fp8')
        else:
            kv_dtype = self.dtype
        self.kv_dtype = kv_dtype

        self.quant_mode = quant_mode
        self.use_parallel_embedding = use_parallel_embedding
        self.embedding_sharding_dim = embedding_sharding_dim

        # Note: the caller's original ``dtype`` value (not the converted one)
        # is what the base class receives.
        super().__init__(num_layers=num_layers,
                         num_heads=num_heads,
                         hidden_size=hidden_size,
                         seq_length=seq_length,
                         vocab_size=vocab_size,
                         hidden_act=hidden_act,
                         max_position_embeddings=max_position_embeddings,
                         dtype=dtype,
                         mlp_hidden_size=mlp_hidden_size,
                         neox_rotary_style=neox_rotary_style,
                         rotary_base=rotary_base,
                         rotary_scaling=rotary_scaling,
                         mapping=mapping,
                         quant_mode=quant_mode,
                         use_parallel_embedding=use_parallel_embedding,
                         embedding_sharding_dim=embedding_sharding_dim,
                         rms_norm_eps=rms_norm_eps)

        vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
        if self.mapping.is_last_pp_rank():
            self.lm_head = ColumnLinear(hidden_size,
                                        vocab_size_padded,
                                        bias=False,
                                        dtype=dtype,
                                        tp_group=mapping.tp_group,
                                        tp_size=mapping.tp_size,
                                        gather_output=True)

    def forward(self,
                input_ids,
                position_ids=None,
                use_cache=False,
                last_token_ids=None,
                kv_cache_params=None,
                attention_params=None,
                hidden_states=None,
                all_reduce_workspace=None):
        """Run the transformer stack and mark the network outputs.

        On the last pipeline rank the last-token hidden states are gathered,
        passed through ``lm_head`` and marked as the ``logits`` output; on
        other ranks the hidden states are marked as ``hidden_states_output``.
        When ``use_cache`` is true and the paged KV cache is disabled, every
        present KV tensor is marked as ``present_key_value_<layer index>``.

        Returns:
            The logits (last rank) or hidden states (other ranks), paired with
            ``presents`` in a tuple only when the present KV tensors were
            marked as outputs.
        """
        stack_output = super().forward(input_ids, position_ids, use_cache,
                                       kv_cache_params, attention_params,
                                       hidden_states, all_reduce_workspace)
        if use_cache:
            hidden_states, presents = stack_output
        else:
            hidden_states = stack_output

        if self.mapping.is_last_pp_rank():
            last_hidden = gather_last_token_logits(
                hidden_states, last_token_ids,
                default_net().plugin_config.remove_input_padding)

            # [batch_size, hidden_size] -> [batch_size, vocab_size]
            primary_output = self.lm_head(last_hidden)
            primary_output.mark_output('logits', self.logits_dtype)
        else:
            hidden_states.mark_output('hidden_states_output', self.dtype)
            primary_output = hidden_states

        # The explicit comparison with False is kept as-is on purpose.
        expose_presents = use_cache and default_net(
        ).plugin_config.paged_kv_cache == False
        if not expose_presents:
            return primary_output

        layer_indices = self.get_transformer_layers(self.mapping,
                                                    self.num_layers)
        for layer_idx, present in zip(layer_indices, presents):
            present.mark_output(f'present_key_value_{layer_idx}',
                                self.kv_dtype)
        return (primary_output, presents)

    def prepare_inputs(
        self,
        max_batch_size,
        max_input_len,
        max_new_tokens,
        use_cache,
        max_beam_width: int = 1,
        max_num_tokens: int = None,
    ):
        '''Build the input tensors for ``forward``.

        The size arguments are handed to ``prepare_basic_inputs`` along with
        the plugin settings of the current network.

        Returns a tuple whose elements line up with the parameters of
        ``forward``. The ``use_cache`` slot of that tuple is always ``True``;
        the ``use_cache`` argument of this method is not read.
        '''
        plugin_config = default_net().plugin_config
        size_per_head = self.hidden_size // self.num_heads

        model_inputs = self.prepare_basic_inputs(
            max_batch_size,
            max_beam_width,
            max_input_len,
            max_new_tokens,
            self.num_kv_heads,
            size_per_head,
            self.num_layers,
            self.kv_dtype,
            remove_input_padding=plugin_config.remove_input_padding,
            use_gpt_attention_plugin=plugin_config.gpt_attention_plugin,
            use_gemm_plugin=plugin_config.gemm_plugin,
            use_custom_all_reduce=plugin_config.use_custom_all_reduce,
            paged_kv_cache=plugin_config.paged_kv_cache,
            tokens_per_block=plugin_config.tokens_per_block,
            dtype=self.dtype,
            num_heads=self.num_heads,
            mapping=self.mapping,
            max_num_tokens=max_num_tokens,
        )

        cache_inputs = KeyValueCacheParams(
            past_key_value=model_inputs['past_key_value'],
            host_past_key_value_lengths=model_inputs[
                'host_past_key_value_lengths'],
            host_max_kv_cache_lengths=model_inputs['host_max_kv_cache_lengths'],
            kv_cache_block_pointers=model_inputs[
                'kv_cache_block_pointers_list'],
            cache_indirection=model_inputs['cache_indirection'],
        )
        attention_inputs = AttentionParams(
            sequence_length=model_inputs['sequence_length'],
            context_lengths=model_inputs['context_lengths'],
            host_context_lengths=model_inputs['host_context_lengths'],
            max_context_length=max_input_len,
            host_request_types=model_inputs['host_request_types'])

        return (model_inputs['input_ids'], model_inputs['position_ids'], True,
                model_inputs['last_token_ids'], cache_inputs,
                attention_inputs, model_inputs['hidden_states_input'],
                model_inputs['all_reduce_workspace'])
